#if defined(_MSC_VER)
#include <BaseTsd.h>
typedef SSIZE_T ssize_t;
#endif

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include "hash.hpp"
#include <Python.h>
#include <numpy/arrayobject.h>

namespace py = pybind11;

// NaN test that avoids <cmath>: NaN is the only value that compares unequal to itself.
// The argument is parenthesized so operator-containing expressions expand with the
// intended precedence. It is evaluated twice, so do not pass expressions with side effects.
#define custom_isnan(value) (!((value) == (value)))

// Registration hooks for the hash-based containers. They are only declared here;
// their definitions are not in this file (presumably in sibling translation units).
// Each one receives the Python module object so it can add its bindings to it.
namespace vaex {
void init_hash_primitives_power_of_two(py::module &);
// Declared, but the call in the module init below is currently commented out.
void init_hash_primitives_prime(py::module &);
void init_hash_string(py::module &);
void init_hash_object(py::module &);
} // namespace vaex

/**
 * One-byte-per-element mask.
 *
 * Entry values as used by the methods of this class:
 *   - 1 is the only value treated as "set" by indices()/raw_offset()/count()/first()/last();
 *   - 2 is the state written by reset() and reported by is_dirty();
 *   - any other value is simply "not set" as far as these methods are concerned.
 *
 * A Mask either owns its storage (length constructor) or borrows it (pointer
 * constructor and view()). Only owned storage is released by the destructor, so
 * borrowed storage must outlive the Mask that points at it.
 */
class Mask {
  public:
    /// Allocates an owned buffer of `length` bytes and puts every entry in the reset state.
    Mask(size_t length) : mask_data(new uint8_t[length]), length(length), _owns_data(true) { reset(); }

    /// Wraps `length` bytes at `mask_data` without taking ownership and without modifying them.
    Mask(uint8_t *mask_data, size_t length)
        : mask_data(mask_data)
        , length(length)
        , _owns_data(false)
    {}

    // Copying an owning Mask would make two objects delete[] the same buffer,
    // so copies are disabled; use view() to get a second handle on the data.
    Mask(const Mask &) = delete;
    Mask &operator=(const Mask &) = delete;

    virtual ~Mask() {
        if (_owns_data)
            delete[] mask_data;
    }

    /**
     * Translates two 0-based positions among the set entries into raw positions.
     *
     * @param i1 logical index of the start entry (0 is the first set entry)
     * @param i2 logical index of the end entry, must be >= i1
     * @return {raw index of entry i1, raw index of entry i2}; a component is -1
     *         when the mask holds too few set entries to reach it
     * @throws std::runtime_error when i2 < i1
     */
    std::pair<int64_t, int64_t> indices(int64_t i1, int64_t i2) {
        if (i2 < i1) {
            throw std::runtime_error("end index should be larger or equal to start index");
        }
        int64_t count = 0;
        int64_t start = -1, end = -1;
        for (int64_t i = 0; i < length; i++) {
            if (mask_data[i] == 1) {
                if (count == i1) {
                    start = i;
                }
                if (count == i2) {
                    end = i;
                    break; // i2 >= i1, so start was already handled
                }
                count++;
            }
        }
        return {start, end};
    }

    /**
     * Raw position of the set entry with 0-based logical index `logical_offset`.
     *
     * Uses the same 0-based convention as indices(): raw_offset(0) is the position
     * of the first set entry.
     *
     * @return the raw index, or -1 when there is no such set entry
     */
    int64_t raw_offset(int64_t logical_offset) {
        int64_t seen = 0; // number of set entries strictly before position i
        for (int64_t i = 0; i < length; i++) {
            if (mask_data[i] == 1) {
                if (seen == logical_offset) {
                    return i;
                }
                seen++;
            }
        }
        return -1;
    }

    /// Overwrites every entry with 2, the state is_dirty() looks for. Runs without the GIL.
    void reset() {
        py::gil_scoped_release release;
        std::fill(mask_data, mask_data + length, 2);
    }

    /// Number of entries equal to 1. Runs without the GIL.
    int64_t count() {
        py::gil_scoped_release release;
        int64_t count = 0;
        for (int64_t i = 0; i < length; i++) {
            if (mask_data[i] == 1) {
                count++;
            }
        }
        return count;
    }

    /// Returns 1 when at least one entry still holds the reset value 2, otherwise 0.
    /// Runs without the GIL.
    int64_t is_dirty() {
        py::gil_scoped_release release;
        for (int64_t i = 0; i < length; i++) {
            if (mask_data[i] == 2) {
                return true;
            }
        }
        return false;
    }

    /**
     * Creates a non-owning Mask over the raw range [start, end) of this mask.
     *
     * The returned object is heap allocated and must be deleted by the caller;
     * it borrows this mask's storage, so this mask has to outlive it.
     *
     * @throws std::runtime_error when end < start, start < 0 or end > length
     */
    Mask *view(int64_t start, int64_t end) {
        if (end < start) {
            throw std::runtime_error("end index should be larger or equal to start index");
        }
        if (start < 0) {
            throw std::runtime_error("start should be >= 0");
        }
        if (end > length) {
            throw std::runtime_error("end should be <= length");
        }
        return new Mask(mask_data + start, end - start);
    }

    /**
     * Raw positions of the first `amount` set entries, in ascending order.
     *
     * The result is shorter than `amount` when the mask holds fewer set entries,
     * and empty when `amount` is 0.
     */
    py::array_t<int64_t> first(int64_t amount) {
        auto ar = py::array_t<int64_t>(amount);
        auto ar_unsafe = ar.mutable_unchecked<1>();
        int64_t found = 0;
        {
            py::gil_scoped_release release;
            // `found < amount` is checked before every write so a request for zero
            // entries never stores into the (empty) scratch array.
            for (int64_t i = 0; i < length && found < amount; i++) {
                if (mask_data[i] == 1) {
                    ar_unsafe(found++) = i;
                }
            }
        }
        // Copy into an array of exactly `found` elements.
        auto ar_trimmed = py::array_t<int64_t>(found);
        auto ar_trimmed_unsafe = ar_trimmed.mutable_unchecked<1>();
        for (int64_t i = 0; i < found; i++) {
            ar_trimmed_unsafe(i) = ar_unsafe(i);
        }
        return ar_trimmed;
    }

    /**
     * Raw positions of the last `amount` set entries, in ascending order.
     *
     * The result is shorter than `amount` when the mask holds fewer set entries,
     * and empty when `amount` is 0.
     */
    py::array_t<int64_t> last(int64_t amount) {
        auto ar = py::array_t<int64_t>(amount);
        auto ar_unsafe = ar.mutable_unchecked<1>();
        int64_t found = 0;
        {
            py::gil_scoped_release release;
            // Scan from the back; the bound on `found` is checked before every write
            // so a request for zero entries never stores into the scratch array.
            for (int64_t i = length - 1; i >= 0 && found < amount; i--) {
                if (mask_data[i] == 1) {
                    ar_unsafe(found++) = i;
                }
            }
        }
        // The scratch array is in descending order; reverse while trimming.
        auto ar_ordered = py::array_t<int64_t>(found);
        auto ar_ordered_unsafe = ar_ordered.mutable_unchecked<1>();
        for (int64_t i = 0; i < found; i++) {
            ar_ordered_unsafe(i) = ar_unsafe(found - 1 - i);
        }
        return ar_ordered;
    }

    uint8_t *mask_data; // first byte of the mask storage
    int64_t length;     // number of entries reachable through mask_data
    bool _owns_data;    // true when the destructor must delete[] mask_data
};

template <typename T>
std::size_t hash_func(T v) {
    vaex::hash<T> h;
    return h(v);
}


// Small value holder bound to Python as "TestObject": a name plus a memoryview.
struct TestObject {
    TestObject(std::string name, py::memoryview bytes) : name(std::move(name)), bytes(std::move(bytes)) {}

    // Overwrites the name while the object is being torn down.
    ~TestObject() { name = "destroyed"; }

    std::string name;
    py::memoryview bytes;
};

// Named collection of shared TestObject instances, bound to Python as "TestContainer".
struct TestContainer {
    TestContainer(std::string name) : name(std::move(name)) {}

    // Appends `member`; the container shares ownership of it from then on.
    void add(std::shared_ptr<TestObject> member) { members.push_back(std::move(member)); }

    std::string name;
    std::vector<std::shared_ptr<TestObject>> members;
};

/**
 * Position of the first byte equal to `needle` in a 1-dimensional buffer.
 *
 * @return 0-based offset of the first match, or -1 when no byte matches
 * @throws std::runtime_error when the buffer is not 1-dimensional
 *
 * NOTE(review): only ndim is validated; the scan covers shape[0] consecutive
 * bytes from the buffer start, so item size and strides are presumably expected
 * to be 1 - confirm against callers.
 */
int64_t find_byte(py::buffer buffer, unsigned char needle) {
    // Declared before the GIL release so it is destroyed after the GIL is re-acquired.
    py::buffer_info info = buffer.request();
    if (info.ndim != 1) {
        throw std::runtime_error("Expected a 1d byte buffer");
    }
    py::gil_scoped_release release;
    const unsigned char *const first = static_cast<const unsigned char *>(info.ptr);
    const unsigned char *const last = first + info.shape[0];
    const unsigned char *const hit = std::find(first, last, needle);
    if (hit == last) {
        return -1;
    }
    return hit - first;
}

/**
 * Number of bytes equal to `needle` in a 1-dimensional buffer.
 *
 * @throws std::runtime_error when the buffer is not 1-dimensional
 *
 * NOTE(review): only ndim is validated; the scan covers shape[0] consecutive
 * bytes from the buffer start, so item size and strides are presumably expected
 * to be 1 - confirm against callers.
 */
int64_t count_byte(py::buffer buffer, unsigned char needle) {
    // Declared before the GIL release so it is destroyed after the GIL is re-acquired.
    py::buffer_info info = buffer.request();
    if (info.ndim != 1) {
        throw std::runtime_error("Expected a 1d byte buffer");
    }
    py::gil_scoped_release release;
    const unsigned char *const first = static_cast<const unsigned char *>(info.ptr);
    const unsigned char *const last = first + info.shape[0];
    return std::count(first, last, needle);
}

// Module initialisation for the `superutils` extension: loads the NumPy C-API,
// then registers the test helpers, the Mask class, the hash containers and the
// byte-search functions.
PYBIND11_MODULE(superutils, m) {
    // _import_array() reports failure through a negative return value with a Python
    // exception set; propagate it instead of continuing with an unusable NumPy C-API.
    if (_import_array() < 0) {
        throw py::error_already_set();
    }

    m.doc() = "fast utils";

    py::class_<TestObject, std::shared_ptr<TestObject>>(m, "TestObject")
        // .def(py::init<std::string>())
        // keep_alive<1, 3>: the new object (1) keeps the `bytes` argument (3) alive,
        // because the stored memoryview refers to that buffer's memory.
        .def(py::init([](std::string name, py::buffer bytes) {
                return new TestObject(name, py::memoryview(bytes.request()));
            }), py::keep_alive<1, 3>())
        .def_property_readonly("name", [](const TestObject &test) { return test.name; })
        .def_property_readonly("bytes", [](const TestObject &test) { return test.bytes; })
    ;
    py::class_<TestContainer, std::shared_ptr<TestContainer>>(m, "TestContainer")
        .def(py::init<std::string>())
        .def("add", &TestContainer::add)
        .def_property_readonly("name", [](const TestContainer &test) { return test.name; })
        .def_property_readonly("members", [](const TestContainer &test) { return test.members; })
    ;

    py::class_<Mask>(m, "Mask", py::buffer_protocol())
        // Mask(length): allocates and owns its own storage.
        .def(py::init<size_t>())
        // Mask(buffer): wraps the buffer's memory without copying it.
        // keep_alive<1, 2>: the new Mask (1) keeps the buffer argument (2) alive,
        // otherwise the borrowed pointer dangles once Python collects the buffer.
        // NOTE(review): item size and strides are not validated here; the Mask reads
        // shape[0] consecutive bytes - confirm callers only pass contiguous byte buffers.
        .def(py::init([](py::buffer mask_array) {
            py::buffer_info info = mask_array.request();
            if (info.ndim != 1) {
                throw std::runtime_error("Expected a 1d byte buffer");
            }
            return new Mask((uint8_t *)info.ptr, info.shape[0]);
        }), py::keep_alive<1, 2>())
        // Exposes the mask storage as a 1-d buffer of bool-formatted, 1-byte-strided items.
        .def_buffer([](Mask &mask) -> py::buffer_info {
            std::vector<ssize_t> strides = {1};
            std::vector<ssize_t> shapes = {mask.length};
            return py::buffer_info((void *)mask.mask_data,                /* Pointer to buffer */
                                   sizeof(bool),                          /* Size of one scalar */
                                   py::format_descriptor<bool>::format(), /* Python struct-style format descriptor */
                                   1,                                     /* Number of dimensions */
                                   shapes,                                /* Buffer dimensions */
                                   strides);
        })
        .def_property_readonly("length", [](const Mask &mask) { return mask.length; })
        .def("indices", &Mask::indices)
        .def("raw_offset", &Mask::raw_offset)
        .def("count", &Mask::count)
        .def("first", &Mask::first)
        .def("last", &Mask::last)
        .def("reset", &Mask::reset)
        .def("is_dirty", &Mask::is_dirty)
        // keep_alive<0, 1>: the returned view (0) keeps the parent mask (1) alive,
        // since the view points into the parent's storage.
        .def("view", &Mask::view, py::keep_alive<0, 1>())
        // .def("reduce", &Mask::reduce)
        ;

    m.def("hash", hash_func<uint64_t>);

    vaex::init_hash_primitives_power_of_two(m);
    // TODO: enable again, needs refactor of parent hash_map class
    // vaex::init_hash_primitives_prime(m);
    vaex::init_hash_string(m);
    vaex::init_hash_object(m);
    m.def("find_byte", find_byte);
    m.def("count_byte", count_byte);
}
